import time

from pyspark.sql.types import DoubleType
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
import pyspark.ml.feature as feature
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
import pandas as pd

# Local Spark session for the demo; keep the log quiet so the ALS output
# is readable on the console.
spark = (
    SparkSession.builder
    .master("local")
    .appName("ALs test")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("WARN")


# TODO: ALS algorithm test demo; can be ignored.
def ALs(path="./ALS_score.csv", top_n=5, seed=None):
    """Explicit-feedback ALS recommendation demo.

    Reads a CSV of (host, guest, attitude) rows, indexes both the host
    and guest name columns through one shared StringIndexer (so the two
    columns map into a single id space), trains an explicit ALS model,
    and returns the top-``top_n`` recommended guest ids for every host.

    Parameters
    ----------
    path : str
        Location of the ratings CSV; must contain ``host``, ``guest``
        and ``attitude`` columns.
    top_n : int
        Number of recommendations kept per user (default 5, matching
        the original behaviour).
    seed : int | None
        Optional seed for the train/test split. ``None`` keeps the
        original non-deterministic split.

    Returns
    -------
    pyspark.sql.DataFrame
        One row per ``hostId`` with columns ``top1`` .. ``top{top_n}``
        holding the recommended guest ids.
    """
    df = spark.read.csv(path, header=True)
    df.show(10, True)

    # A single indexer fitted on the raw "host" column is reused for the
    # guest column (after renaming), so hosts and guests share one id
    # space.  handleInvalid="keep" assigns unseen labels an extra index
    # instead of raising, which keeps the second transform safe.
    host_indexer = feature.StringIndexer() \
        .setInputCol("host") \
        .setOutputCol("hostIndex") \
        .setHandleInvalid("keep") \
        .fit(df)

    data = host_indexer.transform(df) \
        .withColumnRenamed("hostIndex", "hostId") \
        .withColumnRenamed("host", "hostName") \
        .withColumnRenamed("guest", "host")
    data = host_indexer.transform(data) \
        .withColumnRenamed("hostIndex", "guestId") \
        .withColumnRenamed("host", "guestName")

    # BUG FIX: describe() returns a DataFrame, so print() only showed its
    # repr.  show() actually displays the summary statistics.
    data.describe().show()

    # ALS requires a numeric rating column; the CSV is read as strings.
    data = data.withColumn("attitude", col("attitude").cast(DoubleType()))

    if seed is None:
        trainData, testData = data.randomSplit([0.7, 0.3])
    else:
        trainData, testData = data.randomSplit([0.7, 0.3], seed)

    timeOld = time.time()
    als_recommend = ALS() \
        .setRegParam(0.01) \
        .setMaxIter(20) \
        .setItemCol("guestId") \
        .setRatingCol("attitude") \
        .setUserCol("hostId") \
        .setRank(12)
    modelExplicit = als_recommend.fit(trainData)
    print("ALS时间花费为:{time}s".format(time=time.time() - timeOld))

    # Drop rows where ALS produced no prediction (cold-start users/items
    # in the test split yield NaN).
    predictionExplicit = modelExplicit.transform(testData).na.drop()
    predictionExplicit.show(20, False)

    recs = modelExplicit.recommendForAllUsers(top_n)

    # Spread the recommendations array into top1..topN struct columns.
    splitDF = recs.select(
        col("hostId"),
        *[col("recommendations").getItem(i).alias("top{}".format(i + 1))
          for i in range(top_n)])
    splitDF.show(20, False)
    # BUG FIX: printSchema() prints directly and returns None; wrapping
    # it in print() emitted a spurious "None" line.
    splitDF.printSchema()

    # Keep only the guestId field of each (guestId, rating) struct.
    for i in range(1, top_n + 1):
        splitDF = splitDF.withColumn(
            "top{}".format(i),
            col("top{}".format(i)).getField("guestId"))
    splitDF.show(20, False)

    return splitDF



def _main():
    """Entry point: run the ALS algorithm demo."""
    ALs()


if __name__ == '__main__':
    _main()
